/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph_optimizer/op_judge/imply_type/op_impl_type_judge.h"

#include <string>
#include <utility>

#include "graph/debug/ge_attr_define.h"

namespace fe {
/*
 *  @ingroup fe
 *  @brief   construct the imply-type judge for one engine
 *  @param   [in] engine_name                   engine name, forwarded to OpJudgeBase
 *  @param   [in] fe_ops_kernel_info_store_ptr  kernel info store used to query op imply types;
 *                                              Judge and JudgeInSubGraph reject a null store
 *                                              via FE_CHECK_NOTNULL before using it
 */
OpImplTypeJudge::OpImplTypeJudge(const std::string& engine_name, FEOpsKernelInfoStorePtr fe_ops_kernel_info_store_ptr)
    // The store pointer is a by-value sink parameter: move it into the member instead of
    // copying it (presumably a shared_ptr, where a copy costs an extra atomic inc/dec).
    : OpJudgeBase(engine_name), ops_kernel_info_store_ptr_(std::move(fe_ops_kernel_info_store_ptr)) {}
// Nothing to release explicitly; members clean themselves up.
OpImplTypeJudge::~OpImplTypeJudge() = default;

/*
 *  @ingroup fe
 *  @brief   walk every node returned by graph.GetAllNodes() and let JudgeByNode
 *           tag it with the core type and its highest prior imply type
 *  @param   [in|out] graph  compute graph whose op descs are updated
 *  @return  SUCCESS, or the first non-SUCCESS status returned by JudgeByNode
 *           (FE_CHECK_NOTNULL also returns early when the kernel info store is null)
 */
Status OpImplTypeJudge::Judge(ge::ComputeGraph& graph) {
  FE_TIMECOST_START(OpImplTypeJudge);
  // Every node below queries the store, so refuse to start without one.
  FE_CHECK_NOTNULL(ops_kernel_info_store_ptr_);
  for (const auto& current_node : graph.GetAllNodes()) {
    const Status node_status = JudgeByNode(current_node);
    if (node_status != SUCCESS) {
      // Stop at the first failing node and propagate its status unchanged.
      return node_status;
    }
  }
  FE_TIMECOST_END(OpImplTypeJudge, "OpImplTypeJudge during FEGraphOptimizer::OptimizeOriginalGraph");
  return SUCCESS;
}

/*
 *  @brief   store the kernel info store's name on the op as the CORE_TYPE attribute
 *  @param   [in|out] op_desc_ptr  op description to tag
 *  @return  SUCCESS when the attribute was written, FAILED (with a warning) otherwise
 *  @note    dereferences ops_kernel_info_store_ptr_ without a null check; callers in this
 *           file reach it only after Judge has checked the store
 */
Status OpImplTypeJudge::SetCoreType(ge::OpDescPtr op_desc_ptr) {
  const std::string engine_core_type = ops_kernel_info_store_ptr_->GetFEOpsKernelInfoStoreName();
  if (!ge::AttrUtils::SetStr(op_desc_ptr, CORE_TYPE, engine_core_type)) {
    FE_LOGW("Set attr %s failed! Engine name is %s.", CORE_TYPE.c_str(), engine_core_type.c_str());
    return FAILED;
  }
  return SUCCESS;
}

/*
 *  @brief   judge a single node of the main graph
 *           - placeholder/end nodes (IsPlaceOrEnd) are skipped;
 *           - every other op is tagged with CORE_TYPE;
 *           - ops carrying IS_CHECK_SUPPORTED are only logged, not re-judged;
 *           - ops already carrying FE_IMPLY_TYPE are left untouched;
 *           - the rest get their imply type via SetOpImplType
 *  @param   [in] node_ptr  node to judge
 *  @return  SUCCESS, OP_JUDGE_SET_CORE_TYPE_FAILED, or the status of SetOpImplType
 */
Status OpImplTypeJudge::JudgeByNode(ge::NodePtr node_ptr) {
  FE_CHECK_NOTNULL(node_ptr);
  std::string node_type = node_ptr->GetType();
  if (IsPlaceOrEnd(node_type)) {
    return SUCCESS;
  }

  ge::OpDescPtr op_desc_ptr = node_ptr->GetOpDesc();
  FE_CHECK_NOTNULL(op_desc_ptr);
  if (SetCoreType(op_desc_ptr) != SUCCESS) {
    return OP_JUDGE_SET_CORE_TYPE_FAILED;
  }

  int64_t check_supported_result = 0;
  const bool has_check_result = ge::AttrUtils::GetInt(op_desc_ptr, IS_CHECK_SUPPORTED, check_supported_result);
  if (has_check_result) {
    // The NOT_SUPPORTED_FLAG_BIT bit of the stored value marks an unsupported op.
    const bool is_not_supported = (static_cast<uint64_t>(check_supported_result) & NOT_SUPPORTED_FLAG_BIT) != 0;
    FE_LOGD("Op[name=%s,type=%s]: the op has been check_supported, the result is %s.",
            op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str(),
            is_not_supported ? "not supported" : "supported");
    return SUCCESS;
  }

  // An op whose FE imply type is already recorded keeps it.
  if (ge::AttrUtils::HasAttr(op_desc_ptr, FE_IMPLY_TYPE)) {
    return SUCCESS;
  }

  OpImplType resolved_impl_type = EN_RESERVED;
  return SetOpImplType(op_desc_ptr, resolved_impl_type);
}

/*
 *  @ingroup fe
 *  @brief   query the highest prior imply type of an op from the kernel info store and
 *           record it as FE_IMPLY_TYPE and, unless the op is flagged IS_GE_OP, as the
 *           GE attribute ge::ATTR_NAME_IMPLY_TYPE (mapped through IMPL_TYPE_MAP)
 *  @param   [in|out] op_desc_ptr  op description receiving the attributes; must not be null
 *                                 (both callers in this file check it first)
 *  @param   [out]    imply_type   the FE imply type returned by the query
 *  @return  SUCCESS if the attributes were handled or the op is not in the op info lib;
 *           OP_JUDGE_MAP_KEY_FIND_FAILED if the queried type has no entry in IMPL_TYPE_MAP
 */
Status OpImplTypeJudge::SetOpImplType(ge::OpDescPtr op_desc_ptr, OpImplType& imply_type) {
  const std::string op_name = op_desc_ptr->GetName();
  const std::string op_type = op_desc_ptr->GetType();

  // 1. query the imply_type; an op unknown to the op info lib is skipped, not treated as an error
  if (ops_kernel_info_store_ptr_->QueryHighPrioOpImplType(op_desc_ptr, imply_type) != SUCCESS) {
    FE_LOGD("Op[name=%s,type=%s]: the op is not supported by the op info lib.", op_name.c_str(), op_type.c_str());
    return SUCCESS;
  }

  // 2. check the imply_type has a GE counterpart
  auto iter = IMPL_TYPE_MAP.find(imply_type);
  if (iter == IMPL_TYPE_MAP.end()) {
    // A variadic argument consumed by %d must be an int, so cast the enum explicitly.
    REPORT_FE_ERROR("[GraphOpt][OPImplJdg][CheckImplType][Op name=%s,type=%s]: the FE imply type [%d] is invalid.",
                    op_name.c_str(), op_type.c_str(), static_cast<int>(imply_type));
    return OP_JUDGE_MAP_KEY_FIND_FAILED;
  }

  // 3. set the FE imply type on every op; a failed write is reported instead of silently ignored
  if (!ge::AttrUtils::SetInt(op_desc_ptr, FE_IMPLY_TYPE, static_cast<int>(imply_type))) {
    FE_LOGW("Op[name=%s,type=%s]: failed to set the FE_IMPLY_TYPE attribute.", op_name.c_str(), op_type.c_str());
  }

  // 4. ops flagged IS_GE_OP keep their GE imply type; log that truthfully
  bool is_ge_op = false;
  if (ge::AttrUtils::GetBool(op_desc_ptr, IS_GE_OP, is_ge_op) && is_ge_op) {
    FE_LOGD("Op[name=%s,type=%s]: set the FE_IMPLY_TYPE attribute [%s], "
            "the IMPLY_TYPE attribute of the GE op is left unchanged.",
            op_name.c_str(), op_type.c_str(), GetImplTypeString(imply_type).c_str());
    return SUCCESS;
  }

  // 5. every other op also gets the mapped GE imply type
  if (!ge::AttrUtils::SetInt(op_desc_ptr, ge::ATTR_NAME_IMPLY_TYPE, static_cast<int>(iter->second))) {
    FE_LOGW("Op[name=%s,type=%s]: failed to set the IMPLY_TYPE attribute.", op_name.c_str(), op_type.c_str());
  }
  FE_LOGD("Op[name=%s,type=%s]: set the FE_IMPLY_TYPE attribute [%s], set the IMPLY_TYPE attribute [%s].",
          op_name.c_str(), op_type.c_str(), GetImplTypeString(imply_type).c_str(),
          GetGeImplTypeString(iter->second).c_str());
  return SUCCESS;
}

/*
 *  @brief   set the imply type of the nodes returned by graph.GetDirectNode()
 *           (presumably only this graph's own nodes, not nested subgraphs — confirm
 *           against the GE ComputeGraph API) via JudgeInSubGraphByNode
 *  @param   [in|out] graph  (sub)graph whose op descs are updated
 *  @return  SUCCESS, or the first non-SUCCESS status returned by JudgeInSubGraphByNode
 */
Status OpImplTypeJudge::JudgeInSubGraph(ge::ComputeGraph& graph) {
  FE_CHECK_NOTNULL(ops_kernel_info_store_ptr_);
  for (const auto& direct_node : graph.GetDirectNode()) {
    const Status node_status = JudgeInSubGraphByNode(direct_node);
    if (node_status != SUCCESS) {
      return node_status;
    }
  }
  return SUCCESS;
}

/*
 *  @brief   judge a single subgraph node: placeholder/end nodes and ops that already
 *           carry FE_IMPLY_TYPE are left alone; all others get their imply type via
 *           SetOpImplType. Unlike JudgeByNode, CORE_TYPE and IS_CHECK_SUPPORTED are
 *           not touched here.
 *  @param   [in] node_ptr  node to judge
 *  @return  SUCCESS or the status of SetOpImplType
 */
Status OpImplTypeJudge::JudgeInSubGraphByNode(ge::NodePtr node_ptr) {
  FE_CHECK_NOTNULL(node_ptr);
  std::string node_type = node_ptr->GetType();
  if (IsPlaceOrEnd(node_type)) {
    return SUCCESS;
  }

  ge::OpDescPtr op_desc_ptr = node_ptr->GetOpDesc();
  FE_CHECK_NOTNULL(op_desc_ptr);
  if (!ge::AttrUtils::HasAttr(op_desc_ptr, FE_IMPLY_TYPE)) {
    OpImplType resolved_impl_type = EN_RESERVED;
    return SetOpImplType(op_desc_ptr, resolved_impl_type);
  }
  return SUCCESS;
}
}  // namespace fe
